#!/usr/bin/env python3
import pandas as pd, numpy as np, sys

# Columns that together identify one stacked bin. Lens and random stacks are
# joined on exactly these columns, so every one of them must be present in
# both input CSVs.
KEY = ["stack_id","R_G_bin","Mstar_bin","b"]

def die(msg):
    """Report *msg* on standard error and terminate the process with exit status 1."""
    # print() applies str() and appends the newline, matching a manual write.
    print(msg, file=sys.stderr)
    raise SystemExit(1)

def main(lens_path="data/prestacked_stacks_lens.csv",
         rand_path="data/prestacked_stacks_rand.csv",
         out_path="data/prestacked_stacks.csv",
         lens_meta="data/prestacked_meta_lens.csv",
         out_meta="data/prestacked_meta.csv"):
    """Subtract random stacks from lens stacks bin by bin and write the result.

    Each lens row is matched to the random row with the same ``KEY`` values,
    and the random ``gamma_t`` is subtracted from the lens ``gamma_t``. Lens
    rows with no random counterpart (or whose random ``gamma_t`` is NaN) get
    0 subtracted. A warning on stderr reports how many rows that affected.

    Parameters
    ----------
    lens_path : CSV with at least the ``KEY`` columns plus ``gamma_t`` and
        ``weight``.
    rand_path : CSV with at least the ``KEY`` columns plus ``gamma_t``. A bin
        may repeat only if every repeat has the same ``gamma_t``.
    out_path : destination CSV with columns ``KEY + ["gamma_t", "weight"]``,
        sorted by ``stack_id`` then ``b``.
    lens_meta : CSV copied verbatim to ``out_meta`` (best effort: failures
        are reported as a warning, not an error).
    out_meta : destination for the meta copy.

    Exits with status 1 (via ``die``) when an input cannot be read, required
    columns are missing, or the random stacks give conflicting ``gamma_t``
    values for the same bin.
    """
    try:
        lens = pd.read_csv(lens_path)
    except Exception as e:
        die(f"Failed to read lens stacks: {lens_path} ({e})")
    try:
        rand = pd.read_csv(rand_path)
    except Exception as e:
        die(f"Failed to read random stacks: {rand_path} ({e})")

    # Validate schemas; sorted() keeps the reported order deterministic.
    for label, df, extra in (("Lens", lens, ["gamma_t", "weight"]),
                             ("Random", rand, ["gamma_t"])):
        missing = sorted(set(KEY + extra) - set(df.columns))
        if missing:
            die(f"{label} stacks missing columns: {missing}. Have: {list(df.columns)}")

    # Exact duplicate random rows are harmless and are collapsed here. Two
    # different gamma_t values for one bin would make the left join emit one
    # lens row per random match, so that ambiguity is rejected outright.
    rand_bins = rand[KEY + ["gamma_t"]].drop_duplicates()
    conflicting = rand_bins.duplicated(subset=KEY, keep=False)
    if conflicting.any():
        n_bad = len(rand_bins.loc[conflicting, KEY].drop_duplicates())
        die(f"Random stacks have conflicting gamma_t values for {n_bad} bin(s) "
            f"sharing the same {KEY}.")

    # Left-join lens with randoms and subtract per bin.
    merged = lens.merge(
        rand_bins.rename(columns={"gamma_t": "gamma_t_rand"}),
        on=KEY, how="left", validate="many_to_one",
    )
    n_unmatched = int(merged["gamma_t_rand"].isna().sum())
    if n_unmatched:
        print(f"Warning: {n_unmatched} lens row(s) have no random gamma_t "
              f"(missing bin or NaN); subtracting 0 for those rows.",
              file=sys.stderr)
    merged["gamma_t"] = merged["gamma_t"] - merged["gamma_t_rand"].fillna(0.0)

    # Carry each lens row's own weight through the join so gamma_t and weight
    # stay paired row by row. Then drop exact duplicates and sort for readability.
    out = (merged[KEY + ["gamma_t", "weight"]]
           .drop_duplicates()
           .sort_values(["stack_id", "b"])
           .reset_index(drop=True))

    out.to_csv(out_path, index=False)

    # Copy lens meta as the run meta (n_lenses, etc.). This is best effort:
    # a failure is reported but does not abort the run.
    try:
        meta = pd.read_csv(lens_meta)
        meta.to_csv(out_meta, index=False)
    except Exception as e:
        print(f"Warning: could not copy meta: {e}", file=sys.stderr)
        meta_saved = False
    else:
        meta_saved = True

    # Quick summary; only claim the meta was saved when the copy succeeded.
    print(f"Wrote {out_path} with {len(out)} rows across {out['stack_id'].nunique()} stacks.")
    if meta_saved:
        print(f"Saved meta to {out_meta}.")

if __name__ == "__main__":
    import argparse

    # Each CLI flag maps one-to-one onto a keyword argument of main() and uses
    # the same default path.
    cli = argparse.ArgumentParser(
        description="Subtract random stacks from lens stacks per bin."
    )
    for flag, fallback in (
        ("--lens", "data/prestacked_stacks_lens.csv"),
        ("--rand", "data/prestacked_stacks_rand.csv"),
        ("--out", "data/prestacked_stacks.csv"),
        ("--lens-meta", "data/prestacked_meta_lens.csv"),
        ("--out-meta", "data/prestacked_meta.csv"),
    ):
        cli.add_argument(flag, default=fallback)
    opts = cli.parse_args()
    main(
        lens_path=opts.lens,
        rand_path=opts.rand,
        out_path=opts.out,
        lens_meta=opts.lens_meta,
        out_meta=opts.out_meta,
    )
